"""coding=utf-8
Functions and classes related to optimization (weight updates).

Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
Copyright 2018 The Google AI Language Team Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re
import tensorflow as tf
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops

# from npu_bridge.estimator.npu.npu_optimizer import NPUDistributedOptimizer
# from npu_bridge.estimator.npu import npu_loss_scale_optimizer as lso
# from npu_bridge.estimator.npu import npu_loss_scale_manager as lsm_lib


def create_optimizer(loss,
                     init_lr,
                     num_train_steps,
                     num_warmup_steps,
                     hvd=None,
                     manual_fp16=False,
                     use_fp16=False,
                     num_accumulation_steps=1,
                     optimizer_type="adam"):
    """Creates an optimizer training op."""
    global_step = tf.train.get_or_create_global_step()

    # avoid step change in learning rate at end of warmup phase
    if optimizer_type == "adam":
        power = 1.0
        decayed_learning_rate_at_crossover_point = init_lr * (
            (1.0 - float(num_warmup_steps) / float(num_train_steps))**power)
    else:
        power = 0.5
        decayed_learning_rate_at_crossover_point = init_lr

    adjusted_init_lr = init_lr * (init_lr /
                                  decayed_learning_rate_at_crossover_point)
    print(
        'decayed_learning_rate_at_crossover_point = %e, adjusted_init_lr = %e'
        % (decayed_learning_rate_at_crossover_point, adjusted_init_lr))

    # Start the polynomial decay from the adjusted LR computed above so the
    # decayed curve meets the warmup curve at the crossover point, avoiding a
    # step change in the learning rate at the end of warmup.
    learning_rate = tf.constant(value=adjusted_init_lr,
                                shape=[],
                                dtype=tf.float32)

    # Implements linear decay of the learning rate.
    learning_rate = tf.train.polynomial_decay(learning_rate,
                                              global_step,
                                              num_train_steps,
                                              end_learning_rate=0.0,
                                              power=power,
                                              cycle=False)

    # Implements linear warmup. I.e., if global_step < num_warmup_steps, the
    # learning rate will be `global_step/num_warmup_steps * init_lr`.
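    # (A pure-Python sketch of the combined warmup + decay schedule,
    # `_example_schedule`, follows this function for reference.)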
    if num_warmup_steps:
        global_steps_int = tf.cast(global_step, tf.int32)
        warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)

        global_steps_float = tf.cast(global_steps_int, tf.float32)
        warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)

        warmup_percent_done = global_steps_float / warmup_steps_float
        warmup_learning_rate = init_lr * warmup_percent_done

        is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
        learning_rate = ((1.0 - is_warmup) * learning_rate +
                         is_warmup * warmup_learning_rate)

    if optimizer_type == "lamb":
        print("Initializing LAMB Optimizer")
        optimizer = LAMBOptimizer(
            learning_rate=learning_rate,
            weight_decay_rate=0.01,
            beta_1=0.9,
            beta_2=0.999,
            epsilon=1e-6,
            exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
    else:
        print("Initializing ADAM Weight Decay Optimizer")
        # It is recommended that you use this optimizer for fine tuning, since this
        # is how the model was trained (note that the Adam m/v variables are NOT
        # loaded from init_checkpoint.)
        optimizer = AdamWeightDecayOptimizer(
            learning_rate=learning_rate,
            weight_decay_rate=0.01,
            beta_1=0.9,
            beta_2=0.999,
            epsilon=1e-4,
            exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])

    tvars = tf.trainable_variables()
    grads_and_vars = optimizer.compute_gradients(
        loss * 1.0 / num_accumulation_steps, tvars)

    if num_accumulation_steps > 1:
        local_step = tf.get_variable(name="local_step",
                                     shape=[],
                                     dtype=tf.int32,
                                     trainable=False,
                                     initializer=tf.zeros_initializer)
        batch_finite = tf.get_variable(name="batch_finite",
                                       shape=[],
                                       dtype=tf.bool,
                                       trainable=False,
                                       initializer=tf.ones_initializer)
        accum_vars = []
        for tvar in tf.trainable_variables():
            _name = tvar.name.split(":")[0] + "/accum"
            accum_vars.append(
                tf.get_variable(name=_name,
                                shape=tvar.shape.as_list(),
                                dtype=tf.float32,
                                trainable=False,
                                initializer=tf.zeros_initializer()))

        reset_step = tf.cast(
            tf.math.equal(local_step % num_accumulation_steps, 0),
            dtype=tf.bool)
        local_step = tf.cond(
            reset_step, lambda: local_step.assign(tf.ones_like(local_step)),
            lambda: local_step.assign_add(1))

        # Pair each gradient with its accumulation buffer, dropping variables
        # that received no gradient.
        grads_and_vars_and_accums = [
            (gv[0], gv[1], accum_vars[i])
            for i, gv in enumerate(grads_and_vars) if gv[0] is not None
        ]
        grads, tvars, accum_vars = list(zip(*grads_and_vars_and_accums))

        # Overwrite the buffers on the first micro-step of an accumulation
        # window, otherwise add to them.
        accum_vars = tf.cond(
            reset_step,
            lambda: [accum_vars[i].assign(g) for i, g in enumerate(grads)],
            lambda: [accum_vars[i].assign_add(g) for i, g in enumerate(grads)])

        def update(accum_vars):
            return optimizer.apply_gradients(list(zip(accum_vars, tvars)),
                                             global_step=global_step)

        update_step = tf.identity(
            tf.cast(tf.math.equal(local_step % num_accumulation_steps, 0),
                    dtype=tf.bool),
            name="update_step")
        # Apply the accumulated gradients only on the last micro-step of the
        # accumulation window.
        update_op = tf.cond(update_step, lambda: update(accum_vars),
                            lambda: tf.no_op())

        # Advance the global step only when the accumulated gradients are
        # applied; batch_finite is reduced across Horovod workers when hvd
        # is given.
        batch_finite_all = tf.cast(
            hvd.allreduce(tf.cast(batch_finite, tf.int32)),
            tf.bool) if hvd is not None else batch_finite
        new_global_step = tf.cond(
            tf.math.logical_and(update_step, batch_finite_all),
            lambda: global_step + 1, lambda: global_step)
        new_global_step = tf.identity(new_global_step, name='step_update')
        train_op = tf.group(update_op, [global_step.assign(new_global_step)])
    else:
        grads_and_vars = [(g, v) for g, v in grads_and_vars if g is not None]
        grads, tvars = list(zip(*grads_and_vars))

        if tf.flags.FLAGS.npu_bert_clip_by_global_norm:
            if (tf.flags.FLAGS.npu_bert_loss_scale not in [None, -1]
                    and (use_fp16 or manual_fp16)):
                all_are_finite = tf.reduce_all(
                    [tf.reduce_all(tf.is_finite(g)) for g in grads])
            else:
                all_are_finite = tf.constant(True, dtype=tf.bool)

        # This is how the model was pre-trained.
        # ensure global norm is a finite number
        # to prevent clip_by_global_norm from throwing a hissy fit.
        # if tf.flags.FLAGS.npu_bert_clip_by_global_norm:
        #   (clipped_grads, _) = tf.clip_by_global_norm(
        #     grads, clip_norm=1.0,
        #     use_norm=tf.cond(
        #         all_are_finite,
        #         lambda: tf.global_norm(grads),
        #         lambda: tf.constant(1.0)))
        # else:
        #   with tf.name_scope("clip_grads"):
        #     clipped_grads = [
        #       (tf.clip_by_norm(grad, clip_norm=1.0))
        #       if grad is not None else (grad, var) for grad in grads
        #     ]

        with tf.name_scope("apply_grads"):
            train_op = optimizer.apply_gradients(list(zip(grads, tvars)),
                                                 global_step=global_step)

        if tf.flags.FLAGS.npu_bert_clip_by_global_norm:
            new_global_step = tf.cond(all_are_finite, lambda: global_step + 1,
                                      lambda: global_step)
        else:
            new_global_step = global_step + 1
        new_global_step = tf.identity(new_global_step, name='step_update')
        train_op = tf.group(train_op, [global_step.assign(new_global_step)])
    return train_op
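

# Illustration only: a minimal pure-Python sketch of the schedule assembled in
# `create_optimizer` above -- linear warmup to `init_lr`, then polynomial decay
# (to an end LR of 0.0) of the constant handed to tf.train.polynomial_decay.
# The helper name `_example_schedule` is hypothetical and is not used anywhere
# in this module.
def _example_schedule(step, init_lr, decay_start_lr, num_train_steps,
                      num_warmup_steps, power=1.0):
    """Learning rate at `step` under linear warmup + polynomial decay."""
    if num_warmup_steps and step < num_warmup_steps:
        # Linear warmup: global_step / num_warmup_steps * init_lr.
        return init_lr * float(step) / float(num_warmup_steps)
    # tf.train.polynomial_decay with end_learning_rate=0.0 and cycle=False:
    #   lr = decay_start_lr * (1 - min(step, T) / T) ** power
    frac = min(float(step), float(num_train_steps)) / float(num_train_steps)
    return decay_start_lr * (1.0 - frac) ** power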


class AdamWeightDecayOptimizer(tf.train.Optimizer):
    """A basic Adam optimizer that includes "correct" L2 weight decay."""
    def __init__(self,
                 learning_rate,
                 weight_decay_rate=0.0,
                 beta_1=0.9,
                 beta_2=0.999,
                 epsilon=1e-4,
                 exclude_from_weight_decay=None,
                 name="AdamWeightDecayOptimizer"):
        """Constructs a AdamWeightDecayOptimizer."""
        super(AdamWeightDecayOptimizer, self).__init__(False, name)

        self.learning_rate = tf.identity(learning_rate, name='learning_rate')
        self.weight_decay_rate = weight_decay_rate
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon
        self.exclude_from_weight_decay = exclude_from_weight_decay

    def apply_gradients(self,
                        grads_and_vars,
                        global_step=None,
                        name=None,
                        manual_fp16=False):
        """See base class."""
        assignments = []
        for (grad, param) in grads_and_vars:
            with tf.name_scope("apply_one_adam"):
                if grad is None or param is None:
                    continue

                param_name = self._get_variable_name(param.name)
                has_shadow = manual_fp16 and param.dtype.base_dtype != tf.float32
                if has_shadow:
                    # create shadow fp32 weights for fp16 variable
                    param_fp32 = tf.get_variable(name=param_name + "/shadow",
                                                 dtype=tf.float32,
                                                 trainable=False,
                                                 initializer=tf.cast(
                                                     param.initialized_value(),
                                                     tf.float32))
                else:
                    param_fp32 = param

                m = tf.get_variable(name=param_name + "/adam_m",
                                    shape=param.shape.as_list(),
                                    dtype=tf.float32,
                                    trainable=False,
                                    initializer=tf.zeros_initializer())
                v = tf.get_variable(name=param_name + "/adam_v",
                                    shape=param.shape.as_list(),
                                    dtype=tf.float32,
                                    trainable=False,
                                    initializer=tf.zeros_initializer())

                # Standard Adam update.
                next_m = (tf.multiply(self.beta_1, m) +
                          tf.multiply(1.0 - self.beta_1, grad))
                next_v = (tf.multiply(self.beta_2, v) +
                          tf.multiply(1.0 - self.beta_2, tf.square(grad)))

                update = next_m / (tf.sqrt(next_v) + self.epsilon)

                # Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
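                # (A scalar pure-Python sketch of this decoupled update,
                # `_example_adamw_step`, follows this class for reference.)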
                if self._do_use_weight_decay(param_name):
                    update += self.weight_decay_rate * param_fp32

                update_with_lr = self.learning_rate * update

                next_param = param_fp32 - update_with_lr

                if has_shadow:
                    # cast shadow fp32 weights to fp16 and assign to trainable variable
                    param.assign(tf.cast(next_param, param.dtype.base_dtype))
                assignments.extend([
                    param_fp32.assign(next_param),
                    m.assign(next_m),
                    v.assign(next_v)
                ])
        new_global_step = global_step + 1
        new_global_step = tf.identity(new_global_step, name='step_update')
        assignments.extend([global_step.assign(new_global_step)])
        return tf.group(*assignments, name=name)

    def _do_use_weight_decay(self, param_name):
        """Whether to use L2 weight decay for `param_name`."""
        if not self.weight_decay_rate:
            return False
        if self.exclude_from_weight_decay:
            for r in self.exclude_from_weight_decay:
                if re.search(r, param_name) is not None:
                    return False
        return True

    def _get_variable_name(self, param_name):
        """Get the variable name from the tensor name."""
        m = re.match("^(.*):\\d+$", param_name)
        if m is not None:
            param_name = m.group(1)
        return param_name
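

# Illustration only: a minimal scalar sketch of the update performed by
# AdamWeightDecayOptimizer.apply_gradients above. The decoupled weight decay
# is added to the Adam step *after* the m/v moments are formed, so it never
# flows through the moment estimates (unlike adding an L2 term to the loss).
# The helper name `_example_adamw_step` is hypothetical and is not used by
# the optimizer itself.
def _example_adamw_step(param, grad, m, v, lr, beta_1=0.9, beta_2=0.999,
                        epsilon=1e-4, weight_decay_rate=0.01):
    """One AdamWeightDecay step for a single scalar parameter."""
    m = beta_1 * m + (1.0 - beta_1) * grad
    v = beta_2 * v + (1.0 - beta_2) * grad * grad
    update = m / (v ** 0.5 + epsilon)
    # Decoupled weight decay: applied directly to the parameter value.
    update += weight_decay_rate * param
    return param - lr * update, m, v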


class LAMBOptimizer(tf.train.Optimizer):
    """A LAMB optimizer that includes "correct" L2 weight decay."""
    def __init__(self,
                 learning_rate,
                 weight_decay_rate=0.0,
                 beta_1=0.9,
                 beta_2=0.999,
                 epsilon=1e-6,
                 exclude_from_weight_decay=None,
                 name="LAMBOptimizer"):
        """Constructs a LAMBOptimizer."""
        super(LAMBOptimizer, self).__init__(False, name)

        self.learning_rate = tf.identity(learning_rate, name='learning_rate')
        self.weight_decay_rate = weight_decay_rate
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon
        self.exclude_from_weight_decay = exclude_from_weight_decay
        self.steps = 0

    def apply_gradients(self,
                        grads_and_vars,
                        global_step=None,
                        name=None,
                        manual_fp16=False):
        """See base class."""
        assignments = []
        for (grad, param) in grads_and_vars:
            with tf.name_scope("apply_one_lamb"):
                if grad is None or param is None:
                    continue

                param_name = self._get_variable_name(param.name)
                has_shadow = manual_fp16 and param.dtype.base_dtype != tf.float32
                if has_shadow:
                    # create shadow fp32 weights for fp16 variable
                    param_fp32 = tf.get_variable(name=param_name + "/shadow",
                                                 dtype=tf.float32,
                                                 trainable=False,
                                                 initializer=tf.cast(
                                                     param.initialized_value(),
                                                     tf.float32))
                else:
                    param_fp32 = param

                m = tf.get_variable(name=param_name + "/adam_m",
                                    shape=param.shape.as_list(),
                                    dtype=tf.float32,
                                    trainable=False,
                                    initializer=tf.zeros_initializer())
                v = tf.get_variable(name=param_name + "/adam_v",
                                    shape=param.shape.as_list(),
                                    dtype=tf.float32,
                                    trainable=False,
                                    initializer=tf.zeros_initializer())

                # LAMB update
                next_m = (tf.multiply(self.beta_1, m) +
                          tf.multiply(1.0 - self.beta_1, grad))
                next_v = (tf.multiply(self.beta_2, v) +
                          tf.multiply(1.0 - self.beta_2, tf.square(grad)))

                # self.steps is a Python int, so it is incremented while the
                # graph is built (once per variable) and the resulting
                # bias-correction factors are constants in the graph.
                self.steps += 1
                beta1_correction = (1 - self.beta_1 ** self.steps)
                beta2_correction = (1 - self.beta_2 ** self.steps)

                next_m_unbiased = next_m / beta1_correction
                next_v_unbiased = next_v / beta2_correction

                update = next_m_unbiased / (tf.sqrt(next_v_unbiased) +
                                            self.epsilon)

                # Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                if self._do_use_weight_decay(param_name):
                    update += self.weight_decay_rate * param_fp32

                w_norm = linalg_ops.norm(param, ord=2)
                g_norm = linalg_ops.norm(update, ord=2)
                ratio = array_ops.where(
                    math_ops.greater(w_norm, 0),
                    array_ops.where(math_ops.greater(g_norm, 0),
                                    (w_norm / g_norm), 1.0), 1.0)
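                # (A pure-Python sketch of this layer-wise trust ratio,
                # `_example_lamb_trust_ratio`, follows this class for
                # reference.)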

                update_with_lr = ratio * self.learning_rate * update

                next_param = param_fp32 - update_with_lr

                if has_shadow:
                    # cast shadow fp32 weights to fp16 and assign to trainable variable
                    param.assign(tf.cast(next_param, param.dtype.base_dtype))
                assignments.extend([
                    param_fp32.assign(next_param),
                    m.assign(next_m),
                    v.assign(next_v)
                ])
        new_global_step = global_step + 1
        new_global_step = tf.identity(new_global_step, name='step_update')
        assignments.extend([global_step.assign(new_global_step)])
        return tf.group(*assignments, name=name)

    def _do_use_weight_decay(self, param_name):
        """Whether to use L2 weight decay for `param_name`."""
        if not self.weight_decay_rate:
            return False
        if self.exclude_from_weight_decay:
            for r in self.exclude_from_weight_decay:
                if re.search(r, param_name) is not None:
                    return False
        return True

    def _get_variable_name(self, param_name):
        """Get the variable name from the tensor name."""
        m = re.match("^(.*):\\d+$", param_name)
        if m is not None:
            param_name = m.group(1)
        return param_name
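

# Illustration only: a minimal pure-Python sketch of the layer-wise trust
# ratio used by LAMBOptimizer.apply_gradients above. The Adam-style update
# (plus decoupled weight decay) is rescaled by ||w|| / ||update|| for each
# variable, falling back to 1.0 when either norm is zero. The helper name
# `_example_lamb_trust_ratio` is hypothetical and is not used by the
# optimizer itself.
def _example_lamb_trust_ratio(param_values, update_values):
    """LAMB trust ratio for one variable, given flat lists of float values."""
    w_norm = sum(p * p for p in param_values) ** 0.5
    g_norm = sum(u * u for u in update_values) ** 0.5
    if w_norm > 0.0 and g_norm > 0.0:
        return w_norm / g_norm
    return 1.0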
